{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Taxi-v3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "import logging\n",
    "import importlib\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import pandas as pd\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# importlib.reload replaces imp.reload: the `imp` module is deprecated and\n",
    "# was removed in Python 3.12. Reloading presumably resets the logging module\n",
    "# state so that basicConfig below takes effect when the cell is re-run.\n",
    "importlib.reload(logging)\n",
    "logging.basicConfig(level=logging.INFO,\n",
    "        format='%(asctime)s [%(levelname)s] %(message)s',\n",
    "        stream=sys.stdout, datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23:10:32 [INFO] env: <TaxiEnv<Taxi-v3>>\n",
      "23:10:32 [INFO] action_space: Discrete(6)\n",
      "23:10:32 [INFO] observation_space: Discrete(500)\n",
      "23:10:32 [INFO] reward_range: (-inf, inf)\n",
      "23:10:32 [INFO] metadata: {'render.modes': ['human', 'ansi']}\n",
      "23:10:32 [INFO] _max_episode_steps: 200\n",
      "23:10:32 [INFO] _elapsed_steps: None\n",
      "23:10:32 [INFO] id: Taxi-v3\n",
      "23:10:32 [INFO] entry_point: gym.envs.toy_text:TaxiEnv\n",
      "23:10:32 [INFO] reward_threshold: 8\n",
      "23:10:32 [INFO] nondeterministic: False\n",
      "23:10:32 [INFO] max_episode_steps: 200\n",
      "23:10:32 [INFO] _kwargs: {}\n",
      "23:10:32 [INFO] _env_name: Taxi\n"
     ]
    }
   ],
   "source": [
    "env = gym.make('Taxi-v3')\n",
    "env.seed(0)\n",
    "# Log every instance attribute of the environment, then of its spec.\n",
    "for described in (env, env.spec):\n",
    "    for key, value in vars(described).items():\n",
    "        logging.info('%s: %s', key, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23:10:32 [INFO] location of taxi = (0, 1)\n",
      "23:10:32 [INFO] location of passager = (0, 4)\n",
      "23:10:32 [INFO] location of destination = (4, 0)\n",
      "+---------+\n",
      "|R:\u001b[43m \u001b[0m| : :\u001b[34;1mG\u001b[0m|\n",
      "| : | : : |\n",
      "| : : : : |\n",
      "| | : | : |\n",
      "|\u001b[35mY\u001b[0m| : |B: |\n",
      "+---------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "state = env.reset()\n",
    "taxi_env = env.unwrapped  # underlying env object providing decode() and locs\n",
    "taxirow, taxicol, passloc, destidx = taxi_env.decode(state)\n",
    "taxi_location = (taxirow, taxicol)\n",
    "logging.info('location of taxi = %s', taxi_location)\n",
    "logging.info('location of passager = %s', taxi_env.locs[passloc])\n",
    "logging.info('location of destination = %s', taxi_env.locs[destidx])\n",
    "env.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(126, -1, False, {'prob': 1.0})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.step(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------+\n",
      "|R: | : :\u001b[34;1mG\u001b[0m|\n",
      "| :\u001b[43m \u001b[0m| : : |\n",
      "| : : : : |\n",
      "| | : | : |\n",
      "|\u001b[35mY\u001b[0m| : |B: |\n",
      "+---------+\n",
      "  (South)\n"
     ]
    }
   ],
   "source": [
    "env.render()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SARSA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SARSAAgent:\n",
    "    \"\"\"Tabular SARSA agent with epsilon-greedy exploration in 'train' mode.\n",
    "\n",
    "    `q` is a zero-initialised table indexed as q[observation, action].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, gamma=0.9, learning_rate=0.2, epsilon=0.01):\n",
    "        \"\"\"Create the agent.\n",
    "\n",
    "        Args:\n",
    "            env: object exposing `observation_space.n` and `action_space.n`.\n",
    "            gamma: discount factor applied to the next action value.\n",
    "            learning_rate: step size of the TD update.\n",
    "            epsilon: probability of a uniformly random action in 'train' mode.\n",
    "        \"\"\"\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        # Defined here so step() acts greedily even if reset() was never called.\n",
    "        self.mode = None\n",
    "        self.trajectory = []\n",
    "\n",
    "    def reset(self, mode=None):\n",
    "        \"\"\"Start a new episode; in 'train' mode the recorded trajectory is cleared.\"\"\"\n",
    "        self.mode = mode\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory = []\n",
    "\n",
    "    def step(self, observation, reward, done):\n",
    "        \"\"\"Pick an action for `observation`.\n",
    "\n",
    "        In 'train' mode the action is random with probability `epsilon`, the\n",
    "        (observation, reward, done, action) record is appended to the\n",
    "        trajectory, and learn() runs once two records are available.\n",
    "        Otherwise the action with the largest q value is returned.\n",
    "        \"\"\"\n",
    "        if self.mode == 'train' and np.random.uniform() < self.epsilon:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        else:\n",
    "            action = self.q[observation].argmax()\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory += [observation, reward, done, action]\n",
    "            if len(self.trajectory) >= 8:\n",
    "                self.learn()\n",
    "        return action\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"No resources to release.\"\"\"\n",
    "        pass\n",
    "\n",
    "    def learn(self):\n",
    "        \"\"\"Apply one SARSA update using the two most recent trajectory records.\"\"\"\n",
    "        # reward/done are the values passed to step() together with next_state.\n",
    "        state, _, _, action, next_state, reward, done, next_action = \\\n",
    "                        self.trajectory[-8:]\n",
    "\n",
    "        # (1. - done) removes the bootstrap term when done is true.\n",
    "        target = reward + self.gamma * \\\n",
    "                self.q[next_state, next_action] * (1. - done)\n",
    "        td_error = target - self.q[state, action]\n",
    "        self.q[state, action] += self.learning_rate * td_error\n",
    "\n",
    "\n",
    "agent = SARSAAgent(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23:10:32 [INFO] ==== train ====\n",
      "23:10:35 [INFO] ==== test ====\n",
      "23:10:35 [INFO] average episode reward = 8.03 ± 2.50\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhY0lEQVR4nO3de3xU1bn/8c+TEELCHbknIKh4AaQIERGrVUShYAXrq5baHvW0lZZyflZra+Xwa4/W0lLtOaf1VO3hZ1trTy2ltSpHUSvVVqsojYrc0SDRRBACyB0CSZ7fH7MTJsnkOpNMZvb3/XrNiz3P2rNnLZI8s2bttdc2d0dERMIlI9kVEBGR9qfkLyISQkr+IiIhpOQvIhJCSv4iIiHUKdkVaK6+ffv6sGHDkl0NEZGU8vrrr+9y93514ymT/IcNG0ZhYWGyqyEiklLM7L1YcQ37iIiEkJK/iEgIKfmLiISQkr+ISAgp+YuIhJCSv4hICCn5i4iEkJJ/DH9/ZxfFuw5R+tFhKqucLWUHeWXLrpry93cfbtZxjh6vZPu+I/XilVVOyZ7D9Y5z9HglO/YfbfSY5RWRY1ZUVvH+7sO8/t4entuwo1n1iXbkWCU/XL6R197dza6D5TyzbjsAB8sreOzNUr7x+9U1sZ37j/Ls+g956Z0yXny7DIBXinZRWRVZDvx4ZRWlHzXv/6Su45VVlOw5zNHjlXy4L3bb39t9CHfnvd2HANh9sJyHXt7KkWOVNWWxvPrubl4p2sXWXYc4erySpYUlHD1eyZrSvTy5ZlutfTd9uJ+SPYd57d3dAHyw9wjHK6vYc+gYrxTtYu/hYwAU7TzAK1t2sfQfJZTsOUx5RWXM9953+HjNa6IdOVbJHwpLauoc3a66+72waSe/WVnMhm37OVheweYPD1B2oDzm+1WrqnJ+s7KYD/ae+L07VF5B2YFytu87witFu9iwbX+jxwB4t+wgJXtO/Ez3Hz3O7oP133v7viO1/g+qf+dXbd3D5g8P1Nr3uQ07av2M9x89zmNvljZZl2g7DxxldcleFj61gW8sXU1llePu/P2dXbg7SwtLOFZRVes12/Ye4S8bT/yNvLf7EHcsW8/SwpKa2ItvlzX777pkz2H+FvwdFBbvYeP2/ew6WM6Bo8cBWLV1D/sOH4/5t1+y5zD7jhxn18FyDgY/l5HffYZ5j7xBVZWzc/9RDh+rYE3pXha/uKVeWxLJUmU9/4KCAm/ri7y+978b+J/X3mvRf/jwvl3Zuqv+H2+8zhjQnc07DnB2Xk/WfrCv2a8b0CObHfsbThBnDerBxu1N//EnS/fsTmR1ymDPofqJsy3cPGUEP1nxTrP2nTZqIM+s/zBm2YLpZ7Fw+cZEVq1Z+nXPbvIDoSWa+/v8vZmj+OHyTRw5HvvDL9l652bx0eHjdMowKqoaz3HXn38yv14Z8zqoDuGUfl15/taLW/16M3vd3QvqxZOV/M1sGvBTIBN40N0XNbZ/eyT/Ybc/1abHFxFpjS0/mE5mhrXqtQ0l/6Qs72BmmcB9wGVAKfAPM1vm7huSUZ/nNuygT9esZLy1iEiTjlVUkdM5M6HHTNbaPhOAInd/F8DMlgAzgXZP/lVVzo0Pa80gEem4Ep34IXknfPOAkqjnpUGsFjObY2aFZlZYVlbWJhVJjTMeLXP/58cl5Dh/ufUTPHfLRbVicy8+lRXfuIjpZw+sFV9/51RG5/Vo9Hi/vOHEN8/u2fX7HfMuObVme9Nd05gxZlCt8pXzJ/Py7ZMByMwwHryugOJFMyheNIPPjM8H4LKRA2r2zzBqyt9Z+MlG61bt3R9Mp3jRDD43YSgAT8y7oKbsotNrL4zYv3t2zXbxohlsumtag8d98zuXsfaOyyleNINHbjyvJv7CNy+udYzq9gE89rVJrJw/mc+fF6nLwqtGU7xoBq/OvzTme9x06QiKF82oef78rZ+o2V7xjRM/x4f++VwWTD+r1s8D
YNzQXg3WP1ne+M5l/PMFw3jm5gv51tQzauLFi2bwsSG9ALhr5iiKF83g518Y3+ixBvfsUrP9uxsn8ujcSY3uf/JJuTXbb38/9u/PX4L/429PO5Przz8ZgKvOOZHKihfN4EdXn13vdfddO66mTi/ddkmt/YsXzWDj96bx9NcvBOArF53SaD1bK1k9/1iDV/XysLsvBhZDZMy/vSqS6qafPajBsovP6MdXP3EqLxft4r+eLwJg1YJL6ZmTxeYPD/CD5RsZ0juXc4b25tR+3Wq99pyhvfj2tDMBuP/z43lh8052HzzGWyV76ZrdiT98ZRKHjlXw9ocHuPbB13jptkvYfegYs+57melnD2TymQNYc8flVFU5h45VcsGi5wF4dO4kunfpxIj+3bjvhS1894qRdMnKZNbYPJ5as52hfXJ5MfgDORqcYLz76jFMiUr0i64ewy2Xnc7gXjnMvO9l3irZy1+/eeKPKiszg79/+xLue2ELv1v1Pp8Zn88tl53OpEXPc9es0awp2csdV44iIxhX/cFVo5l3yank985lylkDWLFxBx8/7SQuHzmA//v4OiCSuL//1EbefP8jALpkZVK8aAb3vVDEPc9uBqBXbha/n3M+vbt2rqnLpFP78tJtl5CRYeT1yuHT4/LYtD0yMyavVw4r50+mX7dsOmVG+ma3Xn4Gxyur+PQ5kQ+4gT27MLRPLu/vOcxNk09j3uTTeGrNdq4YM7jm/7NHl06c0q9brQ+DtXdcTmWV0yu3Mxef0R+A3375PNZ+sI9FT2/iaxefxpeDb8GPfPk8ntu4g1+9XAzALVNOp2/3zix4bB0//szHuPuZTew8UM7V4/KZdOpJ3PqHt7jo9H5cNKIvA3t2oWdOFhu27eeC0/oCUPrRYfp1z+YXf99KVmYGF43ox61/eKvW79dNk0/j2vNO5ubfv8mr7+7hk6MH0qdrZ/7tU6MAyDTjnmc3s/ymSFL809xJPP7mB8wKku200QNZeNVoLj1zALnZmXgVYPDOjgNkZBjjhvbmuQ07uPHhQs4c2J1euVlMOWsAfy8q449fncTAnl1YU7qXLz5UyPKbLmRY31xWl+wlv1cunTtl8OjcSdz2x7fYUnaIGWMGMSavJ6f268bLt0+u+WAZd3Jvpo4ayJVjB7Nl50EArikYwqRT+3Ln/67nhc1lVFY540/uzRvfuYzsThl0DTpCl0f9Pud0zuSsQT1YOX8y/buf+NBKpKSc8DWz84E73H1q8Hw+gLv/sKHXtNUJ36oq55R/XZ7w4ybaySfl8siNE2sS5t1Xj+G2R9fE3Ld40QwefOldvv9UZPbJw1+cUK/XCnDPs5uYOmogY/J7Nfreh8orOFZRVSuBJUJh8R565WZxWv/uMcvdncdXf8D0sweR3an5X3uPHq/kYHkFfbtlxywv/egw+b1zY5bFUlXlPPbmB8wcO5hOmRk8/uYHVFY5VwffNmLZ9OF+BvXMoWdO25xLOnysgqPHq+iT4J/J+7sPk9XJGNQzB3dn+PzlDDspl79+65KmX9xC+44cxwzG3PFngJoPqg/2HuG+F4r43pWjaj4A093OA0fpmZPVot/z5upQJ3yBfwAjzGw48AEwG7g2SXXp0H509dkUDOvDoJ5dyMmK/GJMP3sg15w7hGvOHcJXflPIs+t31JvC+eULT+FYZRV3P7OZvN45MY/9ralnNqsOXbM70TV2Ho1LwbA+jZabGVed03CCbUiXrEy6ZDX8R9SSxA+QkWG1Ev2sc+qNUNZz5sDGh8Dildu5E7mJzfsADI0a6jAzfnXDuYwa3DZtqf5grDudOa9XDj+4qv5QSTprq959Y5I51XM68BMiUz1/6e4LG9s/3Xv+M8cO5onV2+rFo7+2Q6SH0CunM5071e4R7TpYTsH3V9R6TVWV8+H+owzuFTv5i3QER45VcuhYw9/UJD4dreePuy8H2i3rujvfe3IDV52TV2uY49E3WnaFYVvJsNpnH9bccTmxPpcb6iH0
7ZbNujunUlF54gK1jAxT4pcOL6dzZpvMZpHGhWNADThyvJJfvVzMZ//71ZrY6+/t4Vt/jD1u3laiZ3dEq3viuUeXrBaPF3fL7kSvthgLEJG0E5rkH8vVD6xs9/cc3rcrN0waVr8gKvt3yQr1j0VE2kHK3MA9HQzsERmyqTvEU11216zR9O+ezdRRA+uVi4gkkpJ/O/pt1MU91cYO6cUNk4Yx/exB9U7iioi0FSX/dhRr+qG7N2vqoIhIIqmrmQTnRF1G39RysyIibSGUyf/nf9vSbss3R6+XUn1Nxac+NrgmFr1+iIhIewll8l/09KZ2e6+BPWPPy19351SmnDWAn84+p93qIiJSLXTJv73vPHRHsChVXd2yO/Hg9QVkhWTtEhHpWJR52lj/Hl1qVvyzGFM8RUSSITTJv3hX624wnkipcr9kEUl/oUn+0+99qd3f86xBbbuyo4hIa4Um+bena4M7L1XfmUrDPSLS0egirzbQPbtTvaWYRUQ6EvX8RURCSMm/lbQOj4iksrgymJl9xszWm1mVmRXUKZtvZkVmttnMpkbFx5vZ2qDsXkvRAfF+Lbjr0NA+kat49YEhIh1FvNloHfBp4MXooJmNJHJf3lHANOB+M6te1ewBYA4wInhMi7MOSXHTpac1XFjn4+yBL4xj8T+NT8p9OkVEYokr+bv7RnffHKNoJrDE3cvdfStQBEwws0FAD3df6ZFJ7w8Ds+KpQ7JcMWZw0zsFeuV25nKt0S8iHUhbjUPkASVRz0uDWF6wXTcek5nNMbNCMyssKytrk4q2VmoOVomIRDSZ/M1shZmti/GY2djLYsS8kXhM7r7Y3QvcvaBfv35NVbVdGcZ9145LdjVERFqlyXn+7j6lFcctBYZEPc8HtgXx/BjxlNMlK4OPj+gbs6wlJ4NFRJKhrS7yWgY8Ymb/AQwmcmJ3lbtXmtkBM5sIvAZcB/xXG9WhTUVPUsowqHK4a9Zoumd3qrVev4hIRxRX8jezq4gk737AU2a22t2nuvt6M1sKbAAqgHnuXr2W8lzgISAHeDp4pLTMDOPdhdOTXQ0RkWaLK/m7+2PAYw2ULQQWxogXAqPjed+OQid9RSRV6aojEZEQUvKPgzr+IpKqlPxFREJIyV9EJISU/OOQomvSiYgo+SeCbs0rIqlGyT8O6veLSKpS8hcRCSElfxGREFLyb4WvXXwqoCt8RSR1Kfm3QpeszFrPdb5XRFKNkn8cTKd8RSRFKfm3Qt2U75rrKSIpRsm/FarH+jXmLyKpSsk/AXSlr4ikGiX/BNCwj4ikmriSv5ndY2abzGyNmT1mZr2iyuabWZGZbTazqVHx8Wa2Nii711Kw25yCVRYRqSXenv9zwGh3HwO8DcwHMLORwGxgFDANuN/MqudHPgDMIXJf3xFBeUpTv19EUk1cyd/d/+zuFcHTV4H8YHsmsMTdy919K1AETDCzQUAPd1/pkbGSh4FZ8dShOX787OY2Oa6+AIhIqkrkmP8XOXEz9jygJKqsNIjlBdt14zGZ2RwzKzSzwrKyslZX7GcvFLX6tc2hzwARSTVN3sDdzFYAA2MULXD3J4J9FgAVwG+rXxZjf28kHpO7LwYWAxQUFHSY0ZW6Pf4OUzERkWZqMvm7+5TGys3seuAK4FI/Me2lFBgStVs+sC2I58eIp5TqK3t1ha+IpKp4Z/tMA74NXOnuh6OKlgGzzSzbzIYTObG7yt23AwfMbGIwy+c64Il46iAiIi0X75j/z4DuwHNmttrMfg7g7uuBpcAG4BlgnrtXBq+ZCzxI5CTwFk6cJ2g3s8YOTshxdMJXRFJVk8M+jXH30xopWwgsjBEvBEbH877xuuczH+Px1ZHRpqVfOZ9r/ntli15fb8xfg/4ikmJCeYVvVuaJZvft1rnFrz+pa+Q16viLSKqKq+cfRvd+7hw+NWZQsqshIhIXJf8WuvJjiTlfICKSTKEc9okWzzo9WuNHRFJV6JN/cwztk5vsKoiIJJSSfxzU
7xeRVBX65K+1+EUkjEKf/EVEwij0yT++E74JrIiISDsKdfLPycpseidgwvA+bVwTEZH2Fdp5/kvmTGRon1zKK6qa3DezgS6+pnqKSKoKbc9/4iknMbhXTkKOlaHPABFJMaHt+SfKzVNGcNnIAcmuhohIiyj5x+nmKacnuwoiIi0W2mGfahqxEZEwCn3y1yVeIhJG8d7G8S4zWxPcxevPZjY4qmy+mRWZ2WYzmxoVH29ma4Oye01TZkRE2l28Pf973H2Mu48FngS+C2BmI4HZwChgGnC/mVVPqn8AmEPkvr4jgnIREWlHcSV/d98f9bQrJ0ZRZgJL3L3c3bcSuV/vBDMbBPRw95UeWVTnYWBWPHVoD67BIRFJM3HP9jGzhcB1wD7gkiCcB7watVtpEDsebNeNN3TsOUS+JTB06NB4qyoiIoEme/5mtsLM1sV4zARw9wXuPgT4LfAv1S+LcShvJB6Tuy929wJ3L+jXr1/TrWmF5pxwMM0JEpE002TP392nNPNYjwBPAf9GpEc/JKosH9gWxPNjxDu0jNDPiRKRdBPvbJ8RUU+vBDYF28uA2WaWbWbDiZzYXeXu24EDZjYxmOVzHfBEPHVoD52U/UUkzcQ75r/IzM4AqoD3gK8CuPt6M1sKbAAqgHnuXhm8Zi7wEJADPB08RESkHcWV/N396kbKFgILY8QLgdHxvG8iaR6PiISRxjOaQVM9RSTdhD75ax6PiIRR6JO/iEgYKfmLiISQkn8z6CIvEUk3Sv4iIiGk5C8iEkKhT/7NuZuApnqKSLoJffJ35XURCaHQJ38RkTBS8hcRCSElfxGREFLyb0D37LhvciYi0mGlffLfd+R4o+UNzvax6E1d5CUi6SXtk/+jr5c2vVMMlVWaBiQi6Svtk39rdco40dvXPH8RSTcJSf5m9k0zczPrGxWbb2ZFZrbZzKZGxceb2dqg7N7gdo5tprVpe+hJuQmth4hIRxJ38jezIcBlwPtRsZHAbGAUMA2438wyg+IHgDlE7us7IihPGl3kJSJhlIie/38Ct1G7kz0TWOLu5e6+FSgCJpjZIKCHu690dwceBmYloA4Nau3XCp3kFZF0FlfyN7MrgQ/c/a06RXlASdTz0iCWF2zXjTd0/DlmVmhmhWVlZa2qY1Md+7YddBIR6ZianMxuZiuAgTGKFgD/Clwe62UxYt5IPCZ3XwwsBigoKNAAjYhIgjSZ/N19Sqy4mZ0NDAfeCs7Z5gNvmNkEIj36IVG75wPbgnh+jHibcQ3qi4jU0+phH3df6+793X2Yuw8jktjHufuHwDJgtpllm9lwIid2V7n7duCAmU0MZvlcBzwRfzMaVtHK+foaDhKRdNYmaxi4+3ozWwpsACqAee5eGRTPBR4CcoCng0eb2bBtf6tep9wvIuksYck/6P1HP18ILIyxXyEwOlHv25R4B33mf/JMSj46nJC6iIh0FKG/wrfBKZ3BuM+E4X3asTYiIu0j7ZN/Uyd8m1q6QaeLRSQdpX3yby2N+YtIOlPyFxEJISX/ZtClAiKSbkKf/Jtaw0eJX0TSUdon/9bm7uiLvHTBl4ikm7RP/vFT119E0o+SfwPU2ReRdKbk3wSN+YtIOkr/5N/K5N3Gd5cUEUmq9E/+rXTpWf0BGNQrR71/EUk7bbKqZzqY+4lT+dy5Q+ndtXOyqyIiknDq+TfAzJT4RSRtpX3yb2rhNhGRMEr75F9X327Zya6CiEjSxZX8zewOM/vAzFYHj+lRZfPNrMjMNpvZ1Kj4eDNbG5Tdaykwrabj11BEpGUS0fP/T3cfGzyWA5jZSGA2MAqYBtxvZpnB/g8Ac4jc13dEUJ40SuwiEkZtNewzE1ji7uXuvhUoAiaY2SCgh7uv9MhdVh4GZrVRHYDEXKRVfYw5F50S/8FERDqARCT/fzGzNWb2SzPrHcTygJKofUqDWF6wXTcek5nNMbNCMyssKytLQFXra8mHw5A+uW1SBxGR9tZk8jezFWa2LsZjJpEh
nFOBscB24N+rXxbjUN5IPCZ3X+zuBe5e0K9fv6aq2sAxWvUyEZG01uRFXu4+pTkHMrP/BzwZPC0FhkQV5wPbgnh+jLiIiLSjeGf7DIp6ehWwLtheBsw2s2wzG07kxO4qd98OHDCzicEsn+uAJ+KpQ1OeWf9hWx5eRCQlxbu8w91mNpbI0E0x8BUAd19vZkuBDUAFMM/dK4PXzAUeAnKAp4NHu6k7u0ezfUQkjOJK/u7+T42ULQQWxogXAqPjed+k0QkEEUkTobvCtzX07UBE0k3okn9rOu/q8ItIugld8o+LvgKISJoIffLvkZOV7CqIiLS70CX/up33njlZ5GRlxt5ZRCRNhS75x9IjRzc0E5FwUfIXEQkhJX8RkRBS8m+Ga88bCsDkM/snuSYiIomhwe5mGDW4J8WLZiS7GiIiCaOev4hICCn5i4iEkJK/iEgIhS75XxrjpO13rxhFjy46/SEi4RG65H/DBcPqxWaMGcSaO6a2f2VERJIkdMnfYt5GWEQkXOJO/mb2f8xss5mtN7O7o+LzzawoKJsaFR9vZmuDsnuD2zm2Gy3MKSIS5zx/M7sEmAmMcfdyM+sfxEcCs4FRwGBghZmdHtzK8QFgDvAqsByYRjvfylFEJOzi7fnPBRa5ezmAu+8M4jOBJe5e7u5bgSJgQnDD9x7uvtLdHXgYmBVnHUREpIXiTf6nAxea2Wtm9jczOzeI5wElUfuVBrG8YLtuPCYzm2NmhWZWWFZWFmdVI3RXLhGRZgz7mNkKYGCMogXB63sDE4FzgaVmdgrEPKvqjcRjcvfFwGKAgoICpW0RkQRpMvm7+5SGysxsLvCnYAhnlZlVAX2J9OiHRO2aD2wL4vkx4h1Khk4Ki0iai3fY53FgMoCZnQ50BnYBy4DZZpZtZsOBEcAqd98OHDCzicEsn+uAJ+KsQ8JMHTUAgD5ds5NcExGRthVv8v8lcIqZrQOWANd7xHpgKbABeAaYF8z0gchJ4geJnATeQgea6XPLZacD0DtX9/UVkfQW11RPdz8GfKGBsoXAwhjxQmB0PO8bj8bm+XfLjvx3nNa/WzvVRkQkObSgTZT83rn85ksTGDe0d7KrIiLSppT867hwRL9kV0FEpM2Fbm0fERFR8hcRCaXQJX9d4SsiEsLkLyIiSv4iIqGk5C8iEkKhS/66mYuISAiTv4iIKPmLiISSkr+ISAgp+YuIhFDokr8u8hIRCWHyFxERJX8RkVCKK/mb2e/NbHXwKDaz1VFl882syMw2m9nUqPh4M1sblN0b3M5RRETaUbx38vps9baZ/TuwL9geCcwGRgGDgRVmdnpwK8cHgDnAq8ByYBrteCtHfdSIiCRo2CfovV8D/C4IzQSWuHu5u28lcr/eCWY2COjh7ivd3YGHgVmJqENDPlswpC0PLyKSkhI15n8hsMPd3wme5wElUeWlQSwv2K4bj8nM5phZoZkVlpWVtapivy8saXonEZGQaXLYx8xWAANjFC1w9yeC7c9xotcPEGtwxRuJx+Tui4HFAAUFBZqkKSKSIE0mf3ef0li5mXUCPg2MjwqXAtHjLfnAtiCeHyMuIiLtKBHDPlOATe4ePZyzDJhtZtlmNhwYAaxy9+3AATObGJwnuA54ov4hRUSkLcU12ycwm9pDPrj7ejNbCmwAKoB5wUwfgLnAQ0AOkVk+7TbTJ1K39nw3EZGOKe7k7+43NBBfCCyMES8ERsf7viIi0nqhu8JX8/xFREKY/EVERMlfRCSUlPxFREJIyV9EJIRCk/xH9O+W7CqIiHQYoUn+IiJygpK/iEgIhSb5XzFmMAB9u2UnuSYiIsmXiOUdUsJNl57Gly4cTrfs0DRZRKRBoen5m5kSv4hIIDTJX0RETlDyFxEJobQfB3l07vkU7TyY7GqIiHQoaZ/8x5/ch/En90l2NUREOhQN+4iIhFBcyd/MxprZq2a22swKzWxCVNl8Mysys81mNjUqPt7M1gZl9wa3cxQRkXYUb8//buBO
dx8LfDd4jpmNJHJ7x1HANOB+M8sMXvMAMIfIfX1HBOUiItKO4k3+DvQItnsC24LtmcASdy93961AETDBzAYBPdx9pbs78DAwK846iIhIC8V7wvdm4Fkz+zGRD5JJQTwPeDVqv9IgdjzYrhuPyczmEPmWwNChQ+OsqoiIVGsy+ZvZCmBgjKIFwKXALe7+qJldA/wCmALEGsf3RuIxuftiYDFAQUFBg/uJiEjLNJn83X1KQ2Vm9jDw9eDpH4AHg+1SYEjUrvlEhoRKg+26cRERaUfxjvlvAz4RbE8G3gm2lwGzzSzbzIYTObG7yt23AwfMbGIwy+c64Ik46yAiIi0U75j/jcBPzawTcJRgfN7d15vZUmADUAHMc/fK4DVzgYeAHODp4NGk119/fZeZvdfKevYFdrXytalA7Ut96d5GtS95To4VtMikm/RmZoXuXpDserQVtS/1pXsb1b6OR1f4ioiEkJK/iEgIhSX5L052BdqY2pf60r2Nal8HE4oxfxERqS0sPX8REYmi5C8iEkJpnfzNbFqwpHSRmd2e7Pq0hJn90sx2mtm6qFgfM3vOzN4J/u0dVZYyS2ib2RAze8HMNprZejP7ehBPi/YBmFkXM1tlZm8FbbwziKdTGzPN7E0zezJ4njZtAzCz4qBuq82sMIilTxvdPS0fQCawBTgF6Ay8BYxMdr1aUP+LgHHAuqjY3cDtwfbtwI+C7ZFB+7KB4UG7M4OyVcD5RNZVehr4ZAdo2yBgXLDdHXg7aENatC+olwHdgu0s4DVgYpq18RvAI8CT6fT7GdW+YqBvnVjatDGde/4TgCJ3f9fdjwFLiCw1nRLc/UVgT53wTODXwfavObEcdkotoe3u2939jWD7ALCRyOquadE+AI+ovnl0VvBw0qSNZpYPzODEel6QJm1rQtq0MZ2Tfx5QEvW80eWjU8QAj6yPRPBv/yDeUFvzaMES2slgZsOAc4j0jNOqfcGwyGpgJ/Ccu6dTG38C3AZURcXSpW3VHPizmb1ukeXlIY3amM43cG/R8tEpLiFLaLc3M+sGPArc7O77GxkKTcn2eWQ9q7Fm1gt4zMxGN7J7yrTRzK4Adrr762Z2cXNeEiPWIdtWxwXuvs3M+gPPmdmmRvZNuTamc8+/oWWlU9mO4Gskwb87g3jKLaFtZllEEv9v3f1PQTht2hfN3fcCfyVyy9J0aOMFwJVmVkxkOHWymf0P6dG2Gu6+Lfh3J/AYkaHktGljOif/fwAjzGy4mXUmck/hZUmuU7yWAdcH29dzYjnslFpCO6jLL4CN7v4fUUVp0T4AM+sX9PgxsxwiNznaRBq00d3nu3u+uw8j8nf1vLt/gTRoWzUz62pm3au3gcuBdaRRG5N+xrktH8B0IjNJtgALkl2fFtb9d8B2Ttz68kvAScBfiNw34S9An6j9FwTt3EzUbAKggMgv7RbgZwRXdSe5bR8n8tV3DbA6eExPl/YF9RoDvBm0cR3w3SCeNm0M6nYxJ2b7pE3biMwSfCt4rK/OH+nURi3vICISQuk87CMiIg1Q8hcRCSElfxGREFLyFxEJISV/EZEQUvIXEQkhJX8RkRD6/19flBbhXai2AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):\n",
    "    \"\"\"Run one episode of `agent` in `env`.\n",
    "\n",
    "    The agent is stepped once more on the observation for which `done` is\n",
    "    true before the loop ends; when the step cap is hit the loop ends right\n",
    "    after env.step() instead.\n",
    "\n",
    "    Args:\n",
    "        env: environment providing reset(), step(action) and render().\n",
    "        agent: object providing reset(mode=...), step(...) and close().\n",
    "        max_episode_steps: optional cap on env steps; falsy means no cap.\n",
    "        mode: forwarded to agent.reset().\n",
    "        render: when true, env.render() is called after every agent step.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (sum of rewards, number of env steps taken).\n",
    "    \"\"\"\n",
    "    observation = env.reset()\n",
    "    reward, done = 0., False\n",
    "    agent.reset(mode=mode)\n",
    "    total_reward = 0.\n",
    "    steps_taken = 0\n",
    "    while True:\n",
    "        action = agent.step(observation, reward, done)\n",
    "        if render:\n",
    "            env.render()\n",
    "        if done:\n",
    "            break\n",
    "        observation, reward, done, _ = env.step(action)\n",
    "        total_reward += reward\n",
    "        steps_taken += 1\n",
    "        cap_reached = max_episode_steps and steps_taken >= max_episode_steps\n",
    "        if cap_reached:\n",
    "            break\n",
    "    agent.close()\n",
    "    return total_reward, steps_taken\n",
    "\n",
    "\n",
    "logging.info('==== train ====')\n",
    "episode_rewards = []\n",
    "episode = 0\n",
    "while True:\n",
    "    # Train on the unwrapped env and pass the step cap explicitly.\n",
    "    episode_reward, elapsed_steps = play_episode(\n",
    "            env.unwrapped, agent, mode='train',\n",
    "            max_episode_steps=env._max_episode_steps)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('train episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "    # Stop once the mean reward over the latest (up to) 200 episodes exceeds 8.\n",
    "    recent_mean = np.mean(episode_rewards[-200:])\n",
    "    if recent_mean > 8:\n",
    "        break\n",
    "    episode += 1\n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "\n",
    "logging.info('==== test ====')\n",
    "episode_rewards = []\n",
    "# Lazy generator: each episode is played only when the loop asks for it.\n",
    "test_runs = (play_episode(env, agent) for _ in range(100))\n",
    "for episode, (episode_reward, elapsed_steps) in enumerate(test_runs):\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('test episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "mean_reward = np.mean(episode_rewards)\n",
    "std_reward = np.std(episode_rewards)\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        mean_reward, std_reward)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
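    "Show the learned action values (the agent's estimate of the optimal action values; rows that were never updated remain all zeros)"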
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-3.487602</td>\n",
       "      <td>-3.095619</td>\n",
       "      <td>-3.095238</td>\n",
       "      <td>-3.851571</td>\n",
       "      <td>1.364833</td>\n",
       "      <td>-7.108184</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-2.128847</td>\n",
       "      <td>1.632589</td>\n",
       "      <td>-2.144477</td>\n",
       "      <td>0.719892</td>\n",
       "      <td>7.568908</td>\n",
       "      <td>-2.268642</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-3.597061</td>\n",
       "      <td>-3.596466</td>\n",
       "      <td>-2.821674</td>\n",
       "      <td>-2.618224</td>\n",
       "      <td>2.069454</td>\n",
       "      <td>-4.455161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-5.409483</td>\n",
       "      <td>-5.923042</td>\n",
       "      <td>-5.927367</td>\n",
       "      <td>-6.047687</td>\n",
       "      <td>-7.073330</td>\n",
       "      <td>-7.108867</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>-3.102085</td>\n",
       "      <td>-2.970985</td>\n",
       "      <td>-2.883027</td>\n",
       "      <td>-3.022181</td>\n",
       "      <td>-3.600000</td>\n",
       "      <td>-3.636000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>-1.681193</td>\n",
       "      <td>-0.878103</td>\n",
       "      <td>-1.704407</td>\n",
       "      <td>-1.658715</td>\n",
       "      <td>-3.600000</td>\n",
       "      <td>-3.636000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>-3.119792</td>\n",
       "      <td>-3.005599</td>\n",
       "      <td>-3.150313</td>\n",
       "      <td>-3.150968</td>\n",
       "      <td>-5.014447</td>\n",
       "      <td>-3.636000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>499</th>\n",
       "      <td>-0.707040</td>\n",
       "      <td>-0.396000</td>\n",
       "      <td>-0.707040</td>\n",
       "      <td>5.889600</td>\n",
       "      <td>-3.600000</td>\n",
       "      <td>-3.636000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            0         1         2         3         4         5\n",
       "0    0.000000  0.000000  0.000000  0.000000  0.000000  0.000000\n",
       "1   -3.487602 -3.095619 -3.095238 -3.851571  1.364833 -7.108184\n",
       "2   -2.128847  1.632589 -2.144477  0.719892  7.568908 -2.268642\n",
       "3   -3.597061 -3.596466 -2.821674 -2.618224  2.069454 -4.455161\n",
       "4   -5.409483 -5.923042 -5.927367 -6.047687 -7.073330 -7.108867\n",
       "..        ...       ...       ...       ...       ...       ...\n",
       "495  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000\n",
       "496 -3.102085 -2.970985 -2.883027 -3.022181 -3.600000 -3.636000\n",
       "497 -1.681193 -0.878103 -1.704407 -1.658715 -3.600000 -3.636000\n",
       "498 -3.119792 -3.005599 -3.150313 -3.150968 -5.014447 -3.636000\n",
       "499 -0.707040 -0.396000 -0.707040  5.889600 -3.600000 -3.636000\n",
       "\n",
       "[500 rows x 6 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(agent.q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Show the greedy policy derived from the learned action values (the agent's estimate of the optimal policy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>499</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       0    1    2    3    4    5\n",
       "0    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "1    0.0  0.0  0.0  0.0  1.0  0.0\n",
       "2    0.0  0.0  0.0  0.0  1.0  0.0\n",
       "3    0.0  0.0  0.0  0.0  1.0  0.0\n",
       "4    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "..   ...  ...  ...  ...  ...  ...\n",
       "495  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "496  0.0  0.0  1.0  0.0  0.0  0.0\n",
       "497  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "498  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "499  0.0  0.0  0.0  1.0  0.0  0.0\n",
       "\n",
       "[500 rows x 6 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "greedy_actions = agent.q.argmax(axis=-1)  # highest-valued action index per q row\n",
    "policy = np.eye(agent.action_n)[greedy_actions]  # one-hot row per q row\n",
    "pd.DataFrame(policy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
